{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 👾 PixelCNN using TensorFlow Probability distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own PixelCNN on the Fashion-MNIST dataset, using the `PixelCNN` distribution from TensorFlow Probability (`tfp.distributions`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import datasets, layers, models, optimizers, callbacks\n",
    "import tensorflow_probability as tfp\n",
    "\n",
    "from notebooks.utils import display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = 32  # images are resized to IMAGE_SIZE x IMAGE_SIZE pixels\n",
    "N_COMPONENTS = 5  # passed to PixelCNN as num_logistic_mix\n",
    "EPOCHS = 10\n",
    "BATCH_SIZE = 128\n",
    "SEED = 42\n",
    "\n",
    "# Seed NumPy and TensorFlow so that every Restart & Run All starts from the\n",
    "# same random state (some GPU kernels may still be non-deterministic).\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a73e5a4-1638-411c-8d3c-29f823424458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training images are kept; the labels and the test split are discarded.\n",
    "(x_train, _), (_, _) = datasets.fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebae2f0d-59fd-4796-841f-7213eae638de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(imgs):\n",
    "    \"\"\"Append a channel axis and resize images to IMAGE_SIZE x IMAGE_SIZE.\n",
    "\n",
    "    Args:\n",
    "        imgs: array of images without a channel axis.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of the resized images with a trailing channel axis of\n",
    "        size 1. No normalisation is applied; only the resize itself changes\n",
    "        pixel values.\n",
    "    \"\"\"\n",
    "    with_channel = np.expand_dims(imgs, axis=-1)\n",
    "    resized = tf.image.resize(with_channel, (IMAGE_SIZE, IMAGE_SIZE))\n",
    "    return resized.numpy()\n",
    "\n",
    "\n",
    "input_data = preprocess(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview a few of the preprocessed training images\n",
    "display(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the PixelCNN <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a2a4a1-690e-4c94-b323-86f0e5b691d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-channel images of IMAGE_SIZE x IMAGE_SIZE pixels\n",
    "image_shape = (IMAGE_SIZE, IMAGE_SIZE, 1)\n",
    "\n",
    "# The PixelCNN itself, exposed as a TensorFlow Probability distribution\n",
    "dist = tfp.distributions.PixelCNN(\n",
    "    image_shape=image_shape,\n",
    "    num_resnet=1,\n",
    "    num_hierarchies=2,\n",
    "    num_filters=32,\n",
    "    num_logistic_mix=N_COMPONENTS,\n",
    "    dropout_p=0.3,\n",
    ")\n",
    "\n",
    "# Keras input placeholder for a batch of images\n",
    "image_input = layers.Input(shape=image_shape)\n",
    "\n",
    "# Log-likelihood of the input images under the distribution\n",
    "log_prob = dist.log_prob(image_input)\n",
    "\n",
    "# Wrap everything in a Keras model; the training loss is the negative\n",
    "# mean log-likelihood, attached directly to the model.\n",
    "pixelcnn = models.Model(inputs=image_input, outputs=log_prob)\n",
    "pixelcnn.add_loss(-tf.reduce_mean(log_prob))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the PixelCNN <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ec362d-41fa-473a-ad56-ebeec6cfd3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model. The loss was already attached with add_loss, so only\n",
    "# the optimizer needs to be specified here.\n",
    "pixelcnn.compile(optimizer=optimizers.Adam(learning_rate=0.001))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c525e44b-b3bb-489c-9d35-fcfe3e714e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    \"\"\"Keras callback that samples images from `dist` after every epoch.\n",
    "\n",
    "    Args:\n",
    "        num_img: number of images to sample and display each time.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_img):\n",
    "        # Initialise the base Callback so the attributes Keras expects exist.\n",
    "        super().__init__()\n",
    "        self.num_img = num_img\n",
    "\n",
    "    def generate(self):\n",
    "        \"\"\"Draw `num_img` samples from the PixelCNN distribution as a NumPy array.\"\"\"\n",
    "        return dist.sample(self.num_img).numpy()\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Sample images, display them and save the figure for this epoch.\"\"\"\n",
    "        generated_images = self.generate()\n",
    "        # NOTE(review): assumes ./output already exists or that display()\n",
    "        # creates it when saving -- confirm against notebooks.utils.display.\n",
    "        display(\n",
    "            generated_images,\n",
    "            n=self.num_img,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "        )\n",
    "\n",
    "\n",
    "img_generator_callback = ImageGenerator(num_img=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd6a5a71-eb55-4ec0-9c8c-cb11a382ff90",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train on the images alone (no targets: the loss comes from add_loss),\n",
    "# logging to TensorBoard and sampling images at the end of each epoch.\n",
    "pixelcnn.fit(\n",
    "    x=input_data,\n",
    "    epochs=EPOCHS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    callbacks=[tensorboard_callback, img_generator_callback],\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1f295f-ade0-4040-a6a5-a7b428b08ebc",
   "metadata": {},
   "source": [
    "## 4. Generate images <a name=\"generate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db3cfe3-339e-463d-8af5-fbd403385fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample a fresh batch of images from the trained distribution\n",
    "generated_images = img_generator_callback.generate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80087297-3f47-4e0c-ac89-8758d4386d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show every sampled image\n",
    "display(generated_images, n=img_generator_callback.num_img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
